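"""
Load the FashionMNIST training set and print the tensor shapes of its first
few batches, following the PyTorch quickstart tutorial:
https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
"""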
import itertools
import os
import sys

import matplotlib.pyplot as plt
import torch as tch
import torchvision as tv
from torch.utils.data import DataLoader

# Dataset location. This relative path is resolved against the current
# working directory, not this file's location.
# NOTE(review): presumably meant to be run from this script's folder — confirm.
data_path = '../../../../large_data/DL2/pt/fashion_mnist'
# Sanity check: report whether the dataset directory is already present
# (download=True below lets torchvision fetch the files if they are missing).
print(os.path.exists(data_path))

# FashionMNIST training split; ToTensor converts each image to a tensor.
training_data = tv.datasets.FashionMNIST(
    root=data_path,
    train=True,
    download=True,
    transform=tv.transforms.ToTensor(),
)

batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)

# How many leading batches to inspect.
num_batches_to_show = 6

# Print the shapes of only the first few batches. islice stops after exactly
# num_batches_to_show items, so no extra batch is loaded and discarded, and
# no hand-maintained counter is needed.
for X, y in itertools.islice(train_dataloader, num_batches_to_show):
    print(f'Shape of X [N, C, H, W]: {X.shape}')
    print(f'Shape of y: {y.shape}, {y.dtype}')

